package org.zjvis.graph.analysis.service.algo;

import java.util.*;

/**
 * Computes the strongly connected components (SCCs) of a directed graph with Tarjan's algorithm.
 *
 * <p>The graph is given as a map from each vertex id to the list of its out-neighbor ids. Ids that
 * appear only as neighbors (never as keys) are treated as vertices without outgoing edges, and a
 * {@code null} neighbor list is treated as an empty one. The depth-first search is iterative, so
 * long paths do not overflow the Java call stack.
 *
 * <p>Usage: construct the instance, call {@link #execute()}, then read {@link #getSccMap()} and
 * {@link #getSccCount()}. Instances keep mutable state and are not safe for concurrent use.
 */
public class StronglyConnectedComponents {
    /** Next DFS discovery timestamp to hand out (order in which vertices are first visited). */
    private int time;
    /** Number of vertices: all map keys plus any neighbor-only ids. */
    private final int vertexCount;
    /** DFS discovery timestamp of each vertex index, or -1 if the vertex has not been visited. */
    private final int[] dfn;
    /**
     * Lowest discovery timestamp reachable from the vertex's DFS subtree via tree edges plus at
     * most one edge into a vertex still on the Tarjan stack; a vertex is the root of an SCC exactly
     * when {@code low[v] == dfn[v]} once its subtree is finished.
     */
    private final int[] low;
    /** Whether each vertex index is currently on {@link #stack}. */
    private final boolean[] inStack;
    /** Tarjan stack: visited vertex indices not yet assigned to a component. */
    private final Deque<Integer> stack;
    /** Out-neighbors of each vertex, expressed as vertex indices. */
    private final int[][] adjacency;
    private final Map<Object, Integer> id2index;
    private final List<Object> index2Id;

    private final Map<Object, Integer> sccMap;
    private int sccIndex;

    /**
     * Builds the index-based graph representation; no traversal happens until {@link #execute()}.
     *
     * <p>Keys receive indices {@code 0..size-1} in the map's key iteration order; neighbor-only ids
     * receive the following indices in the order they are first encountered.
     *
     * @param outNeighborMap vertex id to its out-neighbor ids; values may be {@code null} (no edges)
     * @throws NullPointerException if {@code outNeighborMap} is {@code null}
     */
    public StronglyConnectedComponents(Map<Object, List<Object>> outNeighborMap) {
        Objects.requireNonNull(outNeighborMap, "outNeighborMap");
        id2index = new HashMap<>();
        index2Id = new ArrayList<>();
        sccMap = new HashMap<>();
        stack = new ArrayDeque<>();

        // Register every key first so key indices follow the map's iteration order.
        for (Object id : outNeighborMap.keySet()) {
            register(id);
        }
        int keyCount = index2Id.size();

        // Translate neighbor ids to indices; unknown ids become new edge-less vertices.
        List<int[]> keyEdges = new ArrayList<>(keyCount);
        for (int v = 0; v < keyCount; v++) {
            List<Object> neighbors = outNeighborMap.get(index2Id.get(v));
            if (neighbors == null) {
                keyEdges.add(new int[0]);
                continue;
            }
            int[] targets = new int[neighbors.size()];
            int k = 0;
            for (Object neighbor : neighbors) {
                targets[k++] = register(neighbor);
            }
            keyEdges.add(targets);
        }

        vertexCount = index2Id.size();
        adjacency = new int[vertexCount][];
        for (int v = 0; v < vertexCount; v++) {
            adjacency[v] = v < keyCount ? keyEdges.get(v) : new int[0];
        }

        dfn = new int[vertexCount];
        low = new int[vertexCount];
        Arrays.fill(dfn, -1);
        Arrays.fill(low, -1);
        inStack = new boolean[vertexCount];
    }

    /** Returns the index of {@code id}, assigning the next free index if it is new. */
    private int register(Object id) {
        Integer index = id2index.get(id);
        if (index == null) {
            index = index2Id.size();
            id2index.put(id, index);
            index2Id.add(id);
        }
        return index;
    }

    /**
     * Runs Tarjan's algorithm from every not-yet-visited vertex, in index order, filling the SCC
     * map. Calling it again is a no-op because every vertex is already visited.
     */
    public void execute() {
        for (int v = 0; v < vertexCount; v++) {
            if (dfn[v] == -1) {
                tarjan(v);
            }
        }
    }

    /**
     * Runs one Tarjan depth-first search rooted at the given vertex index, assigning component
     * numbers to every SCC completed during this search. Uses an explicit frame stack instead of
     * recursion; components are numbered in the same order a recursive implementation would.
     * Does nothing if the vertex has already been visited.
     *
     * @param current vertex index in {@code [0, vertex count)}
     */
    public void tarjan(Integer current) {
        int root = current;
        if (dfn[root] != -1) {
            return;
        }
        // Each frame is {vertex index, position of the next out-edge to examine}.
        Deque<int[]> callStack = new ArrayDeque<>();
        discover(root);
        callStack.push(new int[] {root, 0});

        while (!callStack.isEmpty()) {
            int[] frame = callStack.peek();
            int v = frame[0];
            int[] targets = adjacency[v];

            if (frame[1] < targets.length) {
                int w = targets[frame[1]++];
                if (dfn[w] == -1) {
                    // Tree edge: "recurse" into w.
                    discover(w);
                    callStack.push(new int[] {w, 0});
                } else if (inStack[w]) {
                    // Edge into the current search path / an unfinished component.
                    low[v] = Math.min(low[v], dfn[w]);
                }
                continue;
            }

            // All edges of v examined: v "returns".
            callStack.pop();
            if (low[v] == dfn[v]) {
                popComponent(v);
            }
            if (!callStack.isEmpty()) {
                int parent = callStack.peek()[0];
                low[parent] = Math.min(low[parent], low[v]);
            }
        }
    }

    /** Stamps {@code v} with the next discovery time and pushes it on the Tarjan stack. */
    private void discover(int v) {
        dfn[v] = low[v] = time++;
        stack.push(v);
        inStack[v] = true;
    }

    /** Pops vertices down to and including {@code componentRoot}, labelling them as one SCC. */
    private void popComponent(int componentRoot) {
        int v;
        do {
            v = stack.pop();
            inStack[v] = false;
            sccMap.put(index2Id.get(v), sccIndex);
        } while (v != componentRoot);
        sccIndex++;
    }

    /**
     * Returns the mapping from vertex id to its component number (0-based, in completion order).
     * This is the live internal map, not a copy.
     */
    public Map<Object, Integer> getSccMap() {
        return sccMap;
    }

    /** Returns the number of components found so far. */
    public Integer getSccCount() {
        return sccIndex;
    }

    /** Small demo: {1..6}, {7, 8} and {9} are the expected components. */
    public static void main(String[] args) throws Exception {
        Map<Object, List<Object>> outNeighborMap = new LinkedHashMap<>();
        outNeighborMap.put("1", Arrays.<Object>asList("2", "3", "4"));
        outNeighborMap.put("2", Arrays.<Object>asList("1", "4"));
        outNeighborMap.put("3", Arrays.<Object>asList("1", "4", "5"));
        outNeighborMap.put("4", Arrays.<Object>asList("1", "6", "2", "3"));
        outNeighborMap.put("5", Arrays.<Object>asList("3", "6"));
        outNeighborMap.put("6", Arrays.<Object>asList("4", "5"));
        outNeighborMap.put("7", Arrays.<Object>asList("8"));
        outNeighborMap.put("8", Arrays.<Object>asList("7"));
        outNeighborMap.put("9", new ArrayList<Object>());

        StronglyConnectedComponents scc = new StronglyConnectedComponents(outNeighborMap);
        scc.execute();
        Map<Object, Integer> map = scc.getSccMap();
        System.out.println(map);
        System.out.println("SCC count: " + scc.getSccCount());
    }
}


